{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing the gradients of the SHT\n",
    "\n",
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "NotImplementedError",
     "evalue": "SGD not support complex64",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNotImplementedError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 11\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpaddle_harmonics\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlegendre\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m legpoly, clm\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpaddle_harmonics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m RealSHT, InverseRealSHT\n\u001b[0;32m---> 11\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSGD not support complex64\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mNotImplementedError\u001b[0m: SGD not support complex64"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from paddle_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights\n",
    "from paddle_harmonics.legendre import legpoly, clm\n",
    "from paddle_harmonics import RealSHT, InverseRealSHT\n",
    "\n",
    "raise NotImplementedError(\"SGD not support complex64 parameter\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem setting\n",
    "\n",
    "We consider the simple problem of fitting the spectral coefficients $\\theta$ such that\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "\\mathcal{L}\n",
    "&= ||\\mathop{\\mathrm{ISHT}}[\\theta] - u^*||^2_{S^2} \\\\\n",
    "&\\approx \\sum_j \\omega_j (\\mathop{\\mathrm{ISHT}}[\\theta](x_j) - u^*(x_j))^2 \\\\\n",
    "&= (S \\, \\theta - u^*)^T \\mathop{\\mathrm{diag}}(\\omega) \\, (S \\, \\theta - u^*) \\\\\n",
    "&= L\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "is minimized.\n",
    "\n",
    "The Vandermonde matrix $S$, which is characterized by $\\mathop{\\mathrm{ISHT}}[\\theta] = S \\theta$ realizes the action of the discrete SHT.\n",
    "\n",
    "The necessary condition for a minimizer of $L$ is\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "& \\nabla_\\theta L = S^T \\mathop{\\mathrm{diag}}(\\omega) \\, (S \\, \\theta - u^*) = 0 \\\\\n",
    "\\Leftrightarrow \\quad & S^T \\mathop{\\mathrm{diag}}(\\omega) \\, S \\; \\theta = S^T \\mathop{\\mathrm{diag}}(\\omega) \\, u^*.\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "On the Gaussian grid, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nlat = 64\n",
    "nlon = 2*nlat\n",
    "grid = \"equiangular\"\n",
    "\n",
    "# for quadrature and plotting\n",
    "if grid == \"legendre-gauss\":\n",
    "    lmax = mmax = nlat\n",
    "    xq, wq = legendre_gauss_weights(nlat)\n",
    "elif grid ==\"equiangular\":\n",
    "    lmax = mmax = nlat//2\n",
    "    xq, wq = clenshaw_curtiss_weights(nlat)\n",
    "\n",
    "sht = RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid)\n",
    "isht = InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid)\n",
    "\n",
    "lat = np.arccos(xq)\n",
    "omega = math.pi * paddle.to_tensor(wq).astype(\"float32\") / nlat\n",
    "omega = omega.reshape(-1, 1)\n",
    "\n",
    "nlon*omega.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://astropedia.astrogeology.usgs.gov/download/Mars/GlobalSurveyor/MOLA/thumbs/Mars_MGS_MOLA_DEM_mosaic_global_1024.jpg -O ./data/mola_topo.jpg\n",
    "\n",
    "import imageio.v3 as iio\n",
    "\n",
    "img = iio.imread('./data/mola_topo.jpg')\n",
    "#convert to grayscale\n",
    "data = np.dot(img[...,:3]/255, [0.299, 0.587, 0.114])\n",
    "# interpolate onto 512x1024 grid:\n",
    "data = nn.functional.interpolate(paddle.to_tensor(data).unsqueeze(0).unsqueeze(0), size=(nlat,nlon)).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from plotting import plot_sphere\n",
    "\n",
    "plot_sphere(data, cmap=\"turbo\", colorbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1.0\n",
    "tmp = paddle.randn([lmax, lmax], dtype=paddle.complex64)\n",
    "theta = tmp\n",
    "theta.stop_gradient = False\n",
    "# theta = paddle.create_parameter(shape=tmp.shape, dtype=tmp.dtype, default_initializer=paddle.nn.initializer.Assign(tmp))\n",
    "optim = paddle.optimizer.SGD(parameters=[theta], learning_rate=lr)\n",
    "\n",
    "for iter in range(40):\n",
    "    optim.clear_grad()\n",
    "    loss = paddle.sum(0.5*omega*(isht(theta) - data)**2)\n",
    "    loss.backward()\n",
    "\n",
    "    # action of the Hessian\n",
    "    with paddle.no_grad():\n",
    "        for m in range(1,mmax):\n",
    "            theta.grad[:,m].multiply_(paddle.to_tensor(0.5))\n",
    "\n",
    "    optim.step()\n",
    "\n",
    "    print(f\"iter: {iter}, loss: {loss}\")\n",
    "\n",
    "# for iter in range(40):\n",
    "#     optim.zero_grad(set_to_none=True)\n",
    "#     loss = torch.sum(0.5*omega*(isht(theta) - data)**2)\n",
    "#     loss.backward()\n",
    "\n",
    "#     # action of the Hessian\n",
    "#     with torch.no_grad():\n",
    "#         for m in range(1,mmax):\n",
    "#             theta.grad[:,m].mul_(0.5)\n",
    "    \n",
    "#     with torch.no_grad():\n",
    "#         theta.add_(theta.grad, alpha=-lr)\n",
    "\n",
    "#     print(f\"iter: {iter}, loss: {loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "what's the best possible loss? $\\theta^* = (S^T \\mathop{\\mathrm{diag}}(\\omega) \\, S)^{-1} S^T \\mathop{\\mathrm{diag}}(\\omega) u^* = \\mathop{\\mathrm{SHT}}[u^*]$ gives us the global minimizer for this problem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(layout='constrained', figsize=(20, 12))\n",
    "subfigs = fig.subfigures(2, 3)\n",
    "\n",
    "# spectral fitting\n",
    "plot_sphere(isht(theta).detach(), fig=subfigs[0,0], cmap=\"turbo\", colorbar=True, title=\"Fit\")\n",
    "plot_sphere(data, fig=subfigs[0,1], cmap=\"turbo\", colorbar=True, title=\"Ground truth\")\n",
    "plot_sphere((isht(theta) - data).detach(), fig=subfigs[0,2], cmap=\"turbo\", colorbar=True, title=\"residual\")\n",
    "\n",
    "# sht(u)\n",
    "plot_sphere(isht(sht(data)).detach(), fig=subfigs[1,0], cmap=\"turbo\", colorbar=True, title=\"isht(sht(u))\")\n",
    "plot_sphere(data, fig=subfigs[1,1], cmap=\"turbo\", colorbar=True, title=\"Ground truth\")\n",
    "plot_sphere((isht(sht(data)) - data).detach(), fig=subfigs[1,2], cmap=\"turbo\", colorbar=True, title=\"residual\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
